import numpy as np
import torch
import torch.nn as nn
from onpolicy.utils.util import get_gard_norm, huber_loss, mse_loss
from onpolicy.utils.valuenorm import ValueNorm
from onpolicy.algorithms.utils.util import check
from onpolicy.algorithms.utils.rl_utils import build_td_lambda_targets_with_weights
import pdb

import torch as th
import torch.nn.functional as F
from torch.optim import RMSprop, Adam



class R_MAPPO():
    """
    Trainer class for MAPPO to update policies.
    :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
    :param policy: (R_MAPPO_Policy) policy to update.
    :param device: (torch.device) specifies the device to run on (cpu/gpu).
    """
    def __init__(self,
                 args,
                 policy,
                 device=torch.device("cpu")):
        """
        Copy the training configuration from args and select the value normalizer.
        :param args: (argparse.Namespace) arguments containing relevant model, policy, and env information.
        :param policy: (R_MAPPO_Policy) policy to update.
        :param device: (torch.device) specifies the device to run on (cpu/gpu).
        """
        self.args = args
        self.device = device
        # Keyword arguments reused when moving batch data onto the training device.
        self.tpdv = {"dtype": torch.float32, "device": device}
        self.policy = policy

        # Hyperparameters stored under the same name they have on args.
        for hyperparam in ("clip_param", "ppo_epoch", "num_mini_batch", "data_chunk_length",
                           "value_loss_coef", "entropy_coef", "max_grad_norm", "huber_delta"):
            setattr(self, hyperparam, getattr(args, hyperparam))

        # Feature switches: private attribute name -> name of the flag on args.
        switches = {
            "_use_recurrent_policy": "use_recurrent_policy",
            "_use_naive_recurrent": "use_naive_recurrent_policy",
            "_use_max_grad_norm": "use_max_grad_norm",
            "_use_clipped_value_loss": "use_clipped_value_loss",
            "_use_huber_loss": "use_huber_loss",
            "_use_popart": "use_popart",
            "_use_valuenorm": "use_valuenorm",
            "_use_value_active_masks": "use_value_active_masks",
            "_use_policy_active_masks": "use_policy_active_masks",
        }
        for attr_name, arg_name in switches.items():
            setattr(self, attr_name, getattr(args, arg_name))

        # PopArt and ValueNorm are mutually exclusive ways of normalizing value targets.
        both_requested = self._use_popart and self._use_valuenorm
        assert both_requested == False, ("self._use_popart and self._use_valuenorm can not be set True simultaneously")

        if self._use_popart:
            # PopArt: the critic's output layer doubles as the normalizer.
            normalizer = self.policy.critic.v_out
        elif self._use_valuenorm:
            normalizer = ValueNorm(1).to(self.device)
        else:
            normalizer = None
        self.value_normalizer = normalizer




    def set_lfiw(self, lfiw):
        """
        Attach the LFIW network and create its optimiser.
        :param lfiw: (torch.nn.Module) network later called as lfiw(inputs, hidden_states); its parameters
                     are optimised with Adam at learning rate args.lfiw_lr.
        """
        # Keep the historical CUDA placement when a GPU exists, but fall back to the
        # trainer's device instead of crashing on CPU-only hosts.
        target_device = th.device("cuda") if th.cuda.is_available() else self.device
        self.lfiw = lfiw.to(target_device)
        self.lfiw_optimiser = Adam(params=self.lfiw.parameters(), lr=self.args.lfiw_lr)
        # Temperature dividing the LFIW outputs before the sigmoid in update_lfiw.
        self.prob_temperature = self.args.prob_temperature

    def cal_value_loss(self, values, value_preds_batch, return_batch, active_masks_batch):
        """
        Compute the (optionally clipped, optionally masked) value function loss.
        :param values: (torch.Tensor) value function predictions.
        :param value_preds_batch: (torch.Tensor) "old" value predictions from data batch (used for value clip loss)
        :param return_batch: (torch.Tensor) reward to go returns.
        :param active_masks_batch: (torch.Tensor) denotes if agent is active or dead at a given timestep.

        :return value_loss: (torch.Tensor) value function loss.
        """
        # Limit how far the new prediction may move away from the old one.
        bounded_step = torch.clamp(values - value_preds_batch, -self.clip_param, self.clip_param)
        clipped_values = value_preds_batch + bounded_step

        if self._use_popart or self._use_valuenorm:
            # Refresh the normalizer statistics with these returns, then regress
            # against the normalized returns.
            self.value_normalizer.update(return_batch)
            target = self.value_normalizer.normalize(return_batch)
        else:
            target = return_batch

        residual_clipped = target - clipped_values
        residual_plain = target - values

        if self._use_huber_loss:
            loss_clipped = huber_loss(residual_clipped, self.huber_delta)
            loss_plain = huber_loss(residual_plain, self.huber_delta)
        else:
            loss_clipped = mse_loss(residual_clipped)
            loss_plain = mse_loss(residual_plain)

        # Element-wise maximum of the clipped and unclipped losses when clipping is enabled.
        per_element = torch.max(loss_plain, loss_clipped) if self._use_clipped_value_loss else loss_plain

        if not self._use_value_active_masks:
            return per_element.mean()

        # Average over active entries only.
        return (per_element * active_masks_batch).sum() / active_masks_batch.sum()

    def ppo_update(self, sample, update_actor=True):
        """
        Update actor and critic networks on one minibatch.
        :param sample: (Tuple) 12-field data batch: share_obs, obs, rnn_states, rnn_states_critic, actions,
                       value_preds, returns, masks, active_masks, old_action_log_probs, advantages,
                       available_actions.
        :param update_actor: (bool) whether to update actor network. When False the actor loss is still
                             computed and returned, but no backward pass or optimizer step is taken for it.

        :return value_loss: (torch.Tensor) value function loss.
        :return critic_grad_norm: (torch.Tensor) gradient norm from critic update.
        :return policy_loss: (torch.Tensor) actor(policy) loss value.
        :return dist_entropy: (torch.Tensor) action entropies.
        :return actor_grad_norm: (torch.Tensor) gradient norm from actor update.
        :return imp_weights: (torch.Tensor) importance sampling weights.
        """
        (share_obs_batch, obs_batch, rnn_states_batch, rnn_states_critic_batch, actions_batch,
         value_preds_batch, return_batch, masks_batch, active_masks_batch, old_action_log_probs_batch,
         adv_targ, available_actions_batch) = sample

        old_action_log_probs_batch = check(old_action_log_probs_batch).to(**self.tpdv)
        adv_targ = check(adv_targ).to(**self.tpdv)
        value_preds_batch = check(value_preds_batch).to(**self.tpdv)
        return_batch = check(return_batch).to(**self.tpdv)
        active_masks_batch = check(active_masks_batch).to(**self.tpdv)

        # Reshape to do in a single forward pass for all steps
        values, action_log_probs, dist_entropy = self.policy.evaluate_actions(share_obs_batch,
                                                                              obs_batch,
                                                                              rnn_states_batch,
                                                                              rnn_states_critic_batch,
                                                                              actions_batch,
                                                                              masks_batch,
                                                                              available_actions_batch,
                                                                              active_masks_batch)

        # ---- actor update ----
        # Probability ratio between the current policy and the one that collected the data.
        imp_weights = torch.exp(action_log_probs - old_action_log_probs_batch)

        surr1 = imp_weights * adv_targ
        surr2 = torch.clamp(imp_weights, 1.0 - self.clip_param, 1.0 + self.clip_param) * adv_targ
        clipped_objective = torch.sum(torch.min(surr1, surr2), dim=-1, keepdim=True)

        if self._use_policy_active_masks:
            # Average the objective over active entries only.
            policy_action_loss = (-clipped_objective * active_masks_batch).sum() / active_masks_batch.sum()
        else:
            policy_action_loss = -clipped_objective.mean()

        policy_loss = policy_action_loss

        self.policy.actor_optimizer.zero_grad()

        if update_actor:
            (policy_loss - dist_entropy * self.entropy_coef).backward()

        if self._use_max_grad_norm:
            actor_grad_norm = nn.utils.clip_grad_norm_(self.policy.actor.parameters(), self.max_grad_norm)
        else:
            actor_grad_norm = get_gard_norm(self.policy.actor.parameters())

        # Only step when gradients were actually produced: stepping on zero-filled
        # gradients can still move the parameters (optimizer momentum, weight decay),
        # which would defeat update_actor=False.
        if update_actor:
            self.policy.actor_optimizer.step()

        # ---- critic update ----
        value_loss = self.cal_value_loss(values, value_preds_batch, return_batch, active_masks_batch)

        self.policy.critic_optimizer.zero_grad()

        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            critic_grad_norm = nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)
        else:
            critic_grad_norm = get_gard_norm(self.policy.critic.parameters())

        self.policy.critic_optimizer.step()

        return value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights



    def train(self, buffer, slow_samples=None, update_actor=True):
        """
        Perform a training update using minibatch GD.
        :param buffer: (SharedReplayBuffer) buffer containing training data.
        :param slow_samples: (tuple) 9-field batch (states, rewards, terminated, obs, value_preds, actions_env,
                             share_obs, rnn_states_critic, masks) used for the additional critic update.
                             Only read when args.dyn_lambda is set; may be None otherwise.
        :param update_actor: (bool) whether to update actor network.

        :return train_info: (dict) contains information regarding training update (e.g. loss, grad norms, etc),
                            each entry averaged over ppo_epoch * num_mini_batch.
        :raises ValueError: if args.dyn_lambda is set but slow_samples is None.
        """
        # Fail before any optimizer step rather than after all PPO epochs have run.
        if self.args.dyn_lambda and slow_samples is None:
            raise ValueError("slow_samples must be provided when args.dyn_lambda is enabled")

        if self._use_popart or self._use_valuenorm:
            advantages = buffer.returns[:-1] - self.value_normalizer.denormalize(buffer.value_preds[:-1])
        else:
            advantages = buffer.returns[:-1] - buffer.value_preds[:-1]

        # Normalize advantages with statistics taken over active entries only:
        # inactive entries are replaced by NaN so nanmean/nanstd skip them.
        advantages_copy = advantages.copy()
        advantages_copy[buffer.active_masks[:-1] == 0.0] = np.nan
        mean_advantages = np.nanmean(advantages_copy)
        std_advantages = np.nanstd(advantages_copy)
        advantages = (advantages - mean_advantages) / (std_advantages + 1e-5)

        train_info = {key: 0 for key in ('value_loss', 'policy_loss', 'dist_entropy',
                                         'actor_grad_norm', 'critic_grad_norm', 'ratio')}

        for _ in range(self.ppo_epoch):
            if self._use_recurrent_policy:
                data_generator = buffer.recurrent_generator(advantages, self.num_mini_batch, self.data_chunk_length)
            elif self._use_naive_recurrent:
                data_generator = buffer.naive_recurrent_generator(advantages, self.num_mini_batch)
            else:
                data_generator = buffer.feed_forward_generator(advantages, self.num_mini_batch)

            for sample in data_generator:
                value_loss, critic_grad_norm, policy_loss, dist_entropy, actor_grad_norm, imp_weights \
                    = self.ppo_update(sample, update_actor)

                train_info['value_loss'] += value_loss.item()
                train_info['policy_loss'] += policy_loss.item()
                train_info['dist_entropy'] += dist_entropy.item()
                train_info['actor_grad_norm'] += actor_grad_norm
                train_info['critic_grad_norm'] += critic_grad_norm
                # Detach so the running sum does not keep autograd history alive.
                train_info['ratio'] += imp_weights.mean().detach()

        # Additional critic update on the slow batch
        if self.args.dyn_lambda:
            self._dyn_lambda_critic_update(slow_samples)

        num_updates = self.ppo_epoch * self.num_mini_batch

        for k in train_info:
            train_info[k] /= num_updates

        return train_info

    def _dyn_lambda_critic_update(self, slow_samples):
        """
        Run one extra critic step on slow_samples using returns built with per-step lambda weights.
        Reads self.lambda_weights, which is set by update_lfiw.
        :param slow_samples: (tuple) 9-field batch, see train().
        """
        gamma = 0.99  # NOTE(review): discount is hard-coded here, not read from args.

        (_, s_rewards, s_terminated, _, s_value_preds, _,
         s_share_obs, s_rnn_states_critic, s_masks) = slow_samples

        # Flatten the leading dimensions so the critic evaluates every entry in one pass.
        input_share_obs = s_share_obs.reshape(-1, s_share_obs.shape[-1])
        input_rnn_states_critic = s_rnn_states_critic.reshape(-1, 1, s_rnn_states_critic.shape[-1])
        input_masks = s_masks.reshape(-1, s_masks.shape[-1])
        out_critic = self.policy.get_values(input_share_obs, input_rnn_states_critic,
                                            input_masks).reshape(*(s_value_preds.shape))

        # Make termination sticky along axis 1: once a step is flagged terminated,
        # every later step is flagged too (cumulative product of the "alive" indicator).
        alive = 1 - s_terminated
        alive = np.cumprod(alive, axis=1, dtype=alive.dtype)
        s_terminated = 1 - alive
        s_masks = 1 - s_terminated

        def per_agent(array):
            # Repeat along the last axis once per agent, then append a trailing unit axis.
            return np.expand_dims(np.repeat(array, self.args.num_agents, axis=-1), axis=-1)

        td_lambda = per_agent(self.lambda_weights.detach().cpu().numpy())
        s_terminated = per_agent(s_terminated)
        s_masks = per_agent(s_masks)
        s_rewards = per_agent(s_rewards)

        # One-step TD error of the current critic against the stored value predictions.
        # NOTE(review): s_value_preds is used as-is here (not denormalized), unlike in
        # calculate_return_dyn_gae -- confirm this is intended when a value normalizer is on.
        targets = s_rewards + gamma * (1 - s_terminated[:, :-1]) * s_value_preds[:, 1:]
        delta_td = out_critic[:, :-1].detach().cpu().numpy() - targets

        # Map |TD error| into [0, 1), rescale to mean 1 and use it to modulate lambda.
        # NOTE(review): np.mean(delta) is zero when every TD error is exactly zero.
        delta = 1 - np.exp(-np.abs(delta_td))
        delta = delta / np.mean(delta)
        td_lambda = np.clip(td_lambda * delta, a_min=0.0, a_max=1.0)

        returns = self.calculate_return_dyn_gae(s_value_preds, s_rewards, s_masks, td_lambda)

        # Follow the critic output's device instead of assuming CUDA.
        device = out_critic.device
        out_critic = out_critic.reshape(-1, 1)
        s_value_preds = th.from_numpy(s_value_preds).to(device).reshape(-1, 1)
        returns = th.from_numpy(returns).to(device).reshape(-1, 1)
        s_masks = th.from_numpy(s_masks).to(device).reshape(-1, 1)

        value_loss = self.cal_value_loss(out_critic, s_value_preds, returns, s_masks)

        self.policy.critic_optimizer.zero_grad()

        (value_loss * self.value_loss_coef).backward()

        if self._use_max_grad_norm:
            nn.utils.clip_grad_norm_(self.policy.critic.parameters(), self.max_grad_norm)

        self.policy.critic_optimizer.step()

    def calculate_return_dyn_gae(self, value_preds, rewards, masks, gae_lambda, gamma=0.99):
        """
        Compute GAE-style returns with a separate lambda for every step.
        All inputs are indexed as [:, step], i.e. axis 1 is the step axis.
        :param value_preds: (np.ndarray) value predictions; needs one more entry along axis 1 than rewards
                            (the extra entry is the bootstrap value).
        :param rewards: (np.ndarray) rewards; its axis-1 length is the number of steps processed.
        :param masks: (np.ndarray) multiplier applied at step + 1 to the bootstrap value and to the
                      accumulated advantage (zero cuts the recursion).
        :param gae_lambda: (np.ndarray) per-step lambda weights, indexed like rewards.
        :param gamma: (float) discount factor; defaults to the previously hard-coded 0.99.

        :return returns: (np.ndarray) same shape as value_preds; the last axis-1 entry is left at zero.
        """
        use_normalizer = self._use_popart or self._use_valuenorm
        returns = np.zeros_like(value_preds)

        gae = 0
        # Walk backwards over the step axis (axis 1), matching the [:, step] indexing below.
        for step in reversed(range(rewards.shape[1])):
            v_now = value_preds[:, step]
            v_next = value_preds[:, step + 1]
            if use_normalizer:
                # Stored predictions are in normalized space; bring them back before mixing with rewards.
                v_now = self.value_normalizer.denormalize(v_now)
                v_next = self.value_normalizer.denormalize(v_next)

            next_mask = masks[:, step + 1]
            delta = rewards[:, step] + gamma * v_next * next_mask - v_now
            gae = delta + gamma * gae_lambda[:, step] * next_mask * gae
            returns[:, step] = gae + v_now
        return returns

    def prep_training(self):
        """Put both the actor and the critic network into training mode."""
        for network in (self.policy.actor, self.policy.critic):
            network.train()

    def prep_rollout(self):
        """Put both the actor and the critic network into evaluation mode."""
        for network in (self.policy.actor, self.policy.critic):
            network.eval()

    def update_lfiw(self, fast_batch, slow_batch):
        """
        Train the LFIW network to tell fast-batch from slow-batch trajectories, then derive weights
        from its scores on the slow batch.
        :param fast_batch: (tuple) 9-field batch (states, rewards, terminated, obs, value_preds, actions_env,
                           share_obs, rnn_states_critic, masks); only states and actions_env are read.
                           Used as the positive class (label 1).
        :param slow_batch: (tuple) same layout; states, actions_env and share_obs are read.
                           Used as the negative class (label 0).

        Side effects:
            - one optimiser step on the LFIW network;
            - self.lambda_weights = sigmoid(slow score / prob_temperature), no gradient attached;
            - self.MSE_weights = sigmoid(slow score) rescaled to mean 1 and stacked
              s_share_obs.shape[2] times along a new axis 2.
        """
        f_states, _, _, _, _, f_actions_env, _, _, _ = fast_batch
        s_states, _, _, _, _, s_actions_env, s_share_obs, _, _ = slow_batch

        # Both batches are rolled out with the fast batch's sizes, as before.
        batch_size = f_states.shape[0]
        max_seq_length = f_states.shape[1] - 1

        def rollout(states, actions_env):
            # Reset the recurrent state, then score every step in order; result is stacked on axis 1.
            self.LFIW_init_hidden(batch_size)
            per_step = [self.LFIW_forward(states, actions_env, t, batch_size).reshape(-1, 1)
                        for t in range(max_seq_length)]
            return th.stack(per_step, dim=1)

        slow_logits = rollout(s_states, s_actions_env)
        fast_logits = rollout(f_states, f_actions_env)

        # The logits form of BCE fuses the sigmoid and the log, avoiding saturation problems
        # of binary_cross_entropy(sigmoid(x), target) for large |x|.
        loss = F.binary_cross_entropy_with_logits(slow_logits, th.zeros_like(slow_logits)) \
            + F.binary_cross_entropy_with_logits(fast_logits, th.ones_like(fast_logits))

        self.lfiw_optimiser.zero_grad()
        loss.backward(retain_graph=False)
        self.lfiw_optimiser.step()

        # Re-score the slow batch with the updated network; no gradients are needed for the weights.
        with th.no_grad():
            slow_preds = rollout(s_states, s_actions_env)

        self.lambda_weights = th.sigmoid(slow_preds / self.prob_temperature)

        mse_weights = th.sigmoid(slow_preds)
        mse_weights = mse_weights / th.mean(mse_weights)
        # presumably s_share_obs.shape[2] is the number of agents -- TODO confirm
        self.MSE_weights = th.stack([mse_weights] * s_share_obs.shape[2], dim=2)
        

    def LFIW_init_hidden(self, batch_size):
        """
        Reset the LFIW recurrent state to the network's initial hidden state, one row per batch entry.
        :param batch_size: (int) number of rows the initial hidden state is expanded to.
        """
        hidden = self.lfiw.init_hidden().expand(batch_size, -1)
        # Keep the state on the same device as the network instead of assuming CUDA;
        # a network without parameters leaves the state wherever init_hidden created it.
        first_param = next(self.lfiw.parameters(), None)
        if first_param is not None:
            hidden = hidden.to(first_param.device)
        self.LFIW_hidden_states = hidden

    def LFIW_forward(self, states, actions_env, t, batch_size):
        """
        Run the LFIW network on step t and advance its recurrent state.
        :param states: array or tensor indexed as [:, t]; the step slice must be 2-D to be concatenated.
        :param actions_env: array or tensor indexed as [:, t]; the step slice is flattened to (batch_size, -1).
        :param t: (int) step index along axis 1.
        :param batch_size: (int) number of rows in the step slice.

        :return: (torch.Tensor) network output viewed as (batch_size, -1).
        """
        step_inputs = (states[:, t], actions_env[:, t].reshape(batch_size, -1))

        # Feed the network on the device its recurrent state already lives on
        # (set by LFIW_init_hidden) instead of assuming CUDA. as_tensor avoids the
        # copy-construct warning th.tensor() emits when given an existing tensor.
        device = self.LFIW_hidden_states.device
        inputs = th.cat([th.as_tensor(x, dtype=th.float32, device=device) for x in step_inputs], dim=1)

        outs, self.LFIW_hidden_states = self.lfiw(inputs, self.LFIW_hidden_states)
        return outs.view(batch_size, -1)